Visualize state mentions in Reddit data

# Create interactive plots that show average sentiment of state mentions and state sentiment by state subreddit
import plotly.express as px
import numpy as np
import pandas as pd
import re

# Load the NLP results; the columns used below are 'states_mentioned' and 'sent'
data = pd.read_csv('tmp/nlp_results.csv')

# Average the 'sent' score for every mentioned state, then turn the result
# back into a two-column frame: 'states_mentioned' and 'Sentiment'
sentiment_by_state = data.groupby('states_mentioned')['sent'].mean()
grouped_dat = sentiment_by_state.rename('Sentiment').reset_index()

# Add state abbreviations: two-letter code -> full name
states = {
    'AK': 'Alaska',
    'AL': 'Alabama',
    'AR': 'Arkansas',
    'AS': 'American Samoa',
    'AZ': 'Arizona',
    'CA': 'California',
    'CO': 'Colorado',
    'CT': 'Connecticut',
    'DC': 'District of Columbia',
    'DE': 'Delaware',
    'FL': 'Florida',
    'GA': 'Georgia',
    'GU': 'Guam',
    'HI': 'Hawaii',
    'IA': 'Iowa',
    'ID': 'Idaho',
    'IL': 'Illinois',
    'IN': 'Indiana',
    'KS': 'Kansas',
    'KY': 'Kentucky',
    'LA': 'Louisiana',
    'MA': 'Massachusetts',
    'MD': 'Maryland',
    'ME': 'Maine',
    'MI': 'Michigan',
    'MN': 'Minnesota',
    'MO': 'Missouri',
    'MP': 'Northern Mariana Islands',
    'MS': 'Mississippi',
    'MT': 'Montana',
    'NA': 'National',
    'NC': 'North Carolina',
    'ND': 'North Dakota',
    'NE': 'Nebraska',
    'NH': 'New Hampshire',
    'NJ': 'New Jersey',
    'NM': 'New Mexico',
    'NV': 'Nevada',
    'NY': 'New York',
    'OH': 'Ohio',
    'OK': 'Oklahoma',
    'OR': 'Oregon',
    'PA': 'Pennsylvania',
    'PR': 'Puerto Rico',
    'RI': 'Rhode Island',
    'SC': 'South Carolina',
    'SD': 'South Dakota',
    'TN': 'Tennessee',
    'TX': 'Texas',
    'UT': 'Utah',
    'VA': 'Virginia',
    'VI': 'Virgin Islands',
    'VT': 'Vermont',
    'WA': 'Washington',
    'WI': 'Wisconsin',
    'WV': 'West Virginia',
    'WY': 'Wyoming',
}

# Reverse lookup (lower-cased full name -> code) used for exact-name matches.
_ABBR_BY_NAME = {name.lower(): abbr for abbr, name in states.items()}


def best_match(x):
    """Return the two-letter code in ``states`` that best matches *x*.

    Matching is attempted in this order, all case-insensitively:

    1. *x* is itself a key of ``states`` (e.g. ``"ny"`` -> ``"NY"``).
    2. *x* is exactly one of the full names (e.g. ``"Texas"`` -> ``"TX"``).
    3. Fuzzy fallback: the characters of *x* must appear in order at the
       start of a name, with any run of word characters allowed between
       them (e.g. ``"Calif"`` -> ``"CA"``). The first name in ``states``
       that matches wins.

    Punctuation is ignored (``"D.C."`` is treated as ``"DC"``) and
    surrounding whitespace is stripped.

    Args:
        x: Free-text state reference. Non-string values are accepted and
            yield ``None``.

    Returns:
        The upper-case code (a key of ``states``), or ``None`` when *x* is
        empty, not a string, or matches nothing.
    """
    if not isinstance(x, str):
        return None

    # Drop punctuation so it can neither act as a regex metacharacter nor
    # block an otherwise exact match.
    query = re.sub(r"[^\w\s]", "", x).strip()
    if not query:
        # An empty pattern would match every name; report "no match" instead.
        return None

    # Exact code lookup must come first: fuzzy-matching a code would pick the
    # wrong state (e.g. "AL" is a prefix-subsequence of "Alaska").
    code = query.upper()
    if code in states:
        return code

    exact = _ABBR_BY_NAME.get(query.lower())
    if exact is not None:
        return exact

    # Fuzzy fallback for shortened names such as "Mass" or "Penn".
    pattern = re.compile(r"\w*".join(re.escape(ch) for ch in query), re.I)
    for abbr, name in states.items():
        if pattern.match(name):
            return abbr
    return None
        
# Attach the two-letter code that best_match finds for each mentioned state
grouped_dat['State'] = grouped_dat['states_mentioned'].apply(best_match)

# Preview the first ten rows
print(grouped_dat.head(10))

# Choropleth of the US: one region per 'State' code, coloured by mean
# 'Sentiment' and showing the value to two decimals on hover
fig = px.choropleth(
    grouped_dat,
    locations='State',
    locationmode="USA-states",
    scope="usa",
    color='Sentiment',
    color_continuous_scale="RdBu_r",
    hover_data={'Sentiment': ':.2f'},
)
fig.update_layout(width=800, height=500)
fig.show()
  states_mentioned  Sentiment State
0          Alabama  -0.171206    AL
1           Alaska   0.339683    AK
2          Arizona   0.121629    AZ
3         Arkansas  -0.046693    AR
4       California  -0.135839    CA
5         Colorado   0.199318    CO
6      Connecticut   0.070433    CT
7         Delaware   0.301779    DE
8          Florida  -0.207535    FL
9          Georgia  -0.098915    GA
# Interactive heatmap: sentiment of each state mention (x axis) within the
# subreddit attributed to each state (y axis)
state_names = [
    'Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', 'Colorado',
    'Connecticut', 'Delaware', 'Florida', 'Georgia', 'Hawaii', 'Idaho',
    'Illinois', 'Indiana', 'Iowa', 'Kansas', 'Kentucky', 'Louisiana', 'Maine',
    'Maryland', 'Massachusetts', 'Michigan', 'Minnesota', 'Mississippi',
    'Missouri', 'Montana', 'Nebraska', 'Nevada', 'New Hampshire', 'New Jersey',
    'New Mexico', 'New York', 'North Carolina', 'North Dakota', 'Ohio',
    'Oklahoma', 'Oregon', 'Pennsylvania', 'Rhode Island', 'South Carolina',
    'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont', 'Virginia',
    'Washington', 'West Virginia', 'Wisconsin', 'Wyoming',
]

# Load the matrix and round every value to two decimals for display
data_mat = pd.read_csv('tmp/state_matrix_nlp.csv').round(2)

# The same 50 names label both axes
fig = px.imshow(
    data_mat,
    x=state_names,
    y=state_names,
    origin='lower',
    color_continuous_scale='RdBu_r',
    labels={"x": "State Mentioned", "y": "Subreddit", "color": "Sentiment"},
)

# Square figure with small tick labels so all 50 names fit on each axis
fig.update_layout(
    width=800,
    height=800,
    yaxis={"tickfont": {"size": 8}},
    xaxis={"tickfont": {"size": 8}},
)
# Optional dark theme, currently disabled:
#fig.update_layout(template='plotly_dark', plot_bgcolor='rgba(169, 169, 169, 0)', paper_bgcolor='rgba(169, 169, 169, 0)')
fig.show()
fig.write_html("../../website-source/plots/interactive_subreddit.html")